import pandas as pd
import numpy as np 
# import matplotlib.pyplot as plt

# Default CSV location read by get_data; the path is relative to the
# current working directory, not to this file.
data_path: str = "./MNIST/train.csv"

def read_mnist(path: str) -> pd.DataFrame:
    """Load a CSV file into a DataFrame, parsing every column as float32.

    :param path: CSV source, forwarded unchanged to ``pd.read_csv``.
    :return: the parsed DataFrame (header row taken from the file).
    """
    return pd.read_csv(path, dtype=np.float32)

def get_data(path: str = data_path, image_shape: tuple = (28, 28)):
    """Load labelled images from an MNIST-style CSV file.

    The first column is used as the label; all remaining columns are pixel
    values, reshaped into one image per row.

    :param path: CSV file to read; defaults to the module-level ``data_path``.
    :param image_shape: shape of a single image. The product of its entries
        must equal the number of pixel columns. Defaults to ``(28, 28)``.
    :return: ``(labels, data)`` where ``labels`` is an int32 array of shape
        ``(n_rows,)`` and ``data`` is a float32 array of shape
        ``(n_rows, *image_shape)``.
    :raises ValueError: if the number of pixel columns does not match
        ``image_shape``.
    """
    df = read_mnist(path)
    # read_mnist parses every column as float32, so cast the labels back to ints.
    labels = df.iloc[:, 0].to_numpy().astype(np.int32)
    pixels = df.iloc[:, 1:].to_numpy()

    # Check the width explicitly: reshape(-1, ...) alone would silently split a
    # row whose width is a multiple of the image size into several images,
    # misaligning them with the labels.
    expected = int(np.prod(image_shape))
    if pixels.shape[1] != expected:
        raise ValueError(
            f"expected {expected} pixel columns for image_shape {tuple(image_shape)}, "
            f"got {pixels.shape[1]}"
        )
    data = pixels.reshape(-1, *image_shape)
    return labels, data

# One-hot lookup table for 10 classes: row i is the encoding of class i.
eye = np.eye(10)

def label2vec(label):
    """Return the one-hot encoding of ``label`` as a new array.

    :param label: a class index, or a sequence/array of class indices.
    :return: a float array of shape ``(10,)`` for a scalar label, or
        ``(len(label), 10)`` for a sequence of labels.
    """
    # A scalar index (eye[3]) is basic indexing and returns a *view* into the
    # shared table. Copying keeps a caller's in-place edit from corrupting
    # every later encoding.
    return eye[label].copy()

def vec2label(vec):
    """Return the index of the highest score in ``vec``.

    If ``vec`` has two or more dimensions, only its first entry along axis 0
    is used. Presumably this targets a single-sample batch such as shape
    ``(1, 10)`` — TODO confirm against callers.

    :param vec: score vector; any array-like (ndarray, list, tuple) is accepted.
    :return: the ``np.argmax`` of the (possibly first-row) scores.
    """
    # asanyarray lets plain lists through (they have no .ndim) while keeping
    # ndarray subclasses unchanged.
    vec = np.asanyarray(vec)
    if vec.ndim >= 2:
        # NOTE(review): rows after the first are silently ignored.
        vec = vec[0]
    return np.argmax(vec)


if __name__ == "__main__":
    # Smoke run: load the default dataset and print the (labels, data) tuple.
    print(get_data(data_path))